"""
Phase 3b: r-g Heterogeneity — OADR Spline, Safe Issuer Split, GFC Break
=========================================================================
The pooled r-g null (Phase 3) is a composition artifact.  This script
decomposes the heterogeneity along three dimensions:
  1. OADR threshold (piecewise-linear spline)
  2. Safe-issuer status (AA- or above)
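  3. Pre/post GFC structural break
It also runs a component decomposition (bond yields, real GDP growth,
inflation, and r-g itself) on the sample with 10-year government bond yields.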

Input:  fiscal_dominance/data/processed/fiscal_panel.csv
        safe_assets/data/processed/safe_asset_panel.csv  (for safe_issuer)
Output: fiscal_dominance/output/tables/phase3b_rg_heterogeneity.csv
        fiscal_dominance/output/tables/table5b_rg_spline.md
        fiscal_dominance/output/tables/table5c_safe_split.md
        fiscal_dominance/output/tables/table5d_gfc_split.md
        fiscal_dominance/output/tables/table5e_decomposition.md
"""

import os
import sys
from pathlib import Path

import numpy as np
import pandas as pd

# Project root.  Set the DEMOGRAPHICS_PROJECT_DIR environment variable to run
# from another checkout location; the default is the original hard-coded path.
PROJECT_DIR = Path(os.environ.get("DEMOGRAPHICS_PROJECT_DIR",
                                  "/mnt/c/demographics_capital_flows"))
FD_DIR = PROJECT_DIR / "fiscal_dominance"
PROCESSED_DIR = FD_DIR / "data" / "processed"
TABLE_DIR = FD_DIR / "output" / "tables"

# PanelGLS is imported from the multilateral project's src directory, which is
# prepended to sys.path so `model` resolves there first.
sys.path.insert(0, str(PROJECT_DIR / "multilateral" / "src"))
from model import PanelGLS


def fit_and_report(y, X, entity_ids, time_ids, feature_names, label):
    """Estimate one PanelGLS specification, print it, and return it with a tidy table.

    Args:
        y: Dependent-variable values.
        X: Regressor matrix whose columns are named by *feature_names*.
        entity_ids: Entity identifier per observation.
        time_ids: Time identifier per observation.
        feature_names: Column labels for *X*, used in the summary and table.
        label: Model label, printed in the banner and stored in the table.

    Returns:
        Tuple ``(model, result_df)``: the fitted PanelGLS instance and the
        DataFrame from ``model.to_dataframe`` with the model label and fit
        statistics (n_obs, n_countries, r_squared, rho) added as columns.
    """
    model = PanelGLS()
    model.fit(y, X, entity_ids, time_ids)

    rule = '=' * 70
    print(f"\n{rule}")
    print(f"  {label}")
    print(f"  N={model.n_obs:,}, {model.n_countries} countries, "
          f"R²={model.r_squared:.4f}, rho={model.rho:.3f}")
    print(rule)
    model.summary(feature_names=feature_names)

    result_df = model.to_dataframe(feature_names=feature_names)
    result_df['model'] = label
    # Broadcast the scalar fit statistics onto every coefficient row.
    for stat in ('n_obs', 'n_countries', 'r_squared', 'rho'):
        result_df[stat] = getattr(model, stat)
    return model, result_df


def fmt(x, decimals=1):
    """Render *x* in fixed-point notation with *decimals* places and an explicit sign."""
    spec = "+." + str(decimals) + "f"
    return format(x, spec)


def stars(p):
    """Return the significance marker for p-value *p*.

    Thresholds are strict: p < 0.001 -> "***", p < 0.01 -> "**",
    p < 0.05 -> "*", p < 0.1 -> "†", otherwise "" (including NaN).
    """
    for cutoff, marker in ((0.001, "***"), (0.01, "**"), (0.05, "*"), (0.1, "†")):
        if p < cutoff:
            return marker
    return ""


def write_markdown_table(path, title, headers, rows, notes=None):
    """Write a pipe-style markdown table, with optional italic note, to *path*.

    The first column is left-aligned and the remaining columns right-aligned.
    Parent directories are created when missing, and literal ``|`` characters
    in header or cell text are escaped so they cannot split a column.

    Args:
        path: Destination ``pathlib.Path``; overwritten if it already exists.
        title: Text of the ``###`` heading placed above the table.
        headers: Column header labels.
        rows: Iterable of row sequences; each cell is rendered with ``str``.
        notes: Optional footnote rendered in italics below the table.
    """
    def _cell(value):
        # An unescaped pipe would be read as a column separator.
        return str(value).replace("|", r"\|")

    lines = [f"### {title}", ""]
    lines.append("| " + " | ".join(_cell(h) for h in headers) + " |")
    lines.append("|" + "|".join(["--:" if i > 0 else ":--" for i in range(len(headers))]) + "|")
    for row in rows:
        lines.append("| " + " | ".join(_cell(c) for c in row) + " |")
    if notes:
        lines.append("")
        lines.append(f"*{notes}*")
    lines.append("")
    # Output directories may not exist yet on a fresh checkout.
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text("\n".join(lines), encoding="utf-8")
    print(f"  Saved: {path}")


def extract_coef(model, feature_names, var):
    """Return ``(coefficient, standard error, p-value)`` for *var* from *model*.

    *var* is located by its position in *feature_names*, which must match the
    column order the model was fitted with.  Raises ``ValueError`` (from
    ``list.index``) when *var* is not in *feature_names*.
    """
    position = feature_names.index(var)
    return tuple(getattr(model, attr)[position] for attr in ("beta", "se", "pvalues"))


def _load_panel():
    """Load the fiscal panel and attach an integer ``safe_issuer`` flag.

    Reads ``fiscal_panel.csv`` from PROCESSED_DIR and left-merges
    ``safe_issuer`` from the safe-asset panel on (iso3, year).  Country-years
    missing from the safe-asset panel get ``safe_issuer = 0``, as does every
    row when that file does not exist.

    Raises:
        pandas.errors.MergeError: if the safe-asset panel has duplicated
            (iso3, year) keys, which would otherwise silently duplicate
            fiscal-panel rows.
    """
    df = pd.read_csv(PROCESSED_DIR / "fiscal_panel.csv")
    print(f"Fiscal panel: {len(df):,} obs, {df['iso3'].nunique()} countries")

    safe_path = PROJECT_DIR / "safe_assets" / "data" / "processed" / "safe_asset_panel.csv"
    if not safe_path.exists():
        print(f"WARNING: {safe_path} not found — setting safe_issuer = 0 for all observations")
        df['safe_issuer'] = 0
        return df

    safe_df = pd.read_csv(safe_path, usecols=['iso3', 'year', 'safe_issuer'])
    df = df.merge(safe_df, on=['iso3', 'year'], how='left', validate='many_to_one')
    df['safe_issuer'] = df['safe_issuer'].fillna(0).astype(int)
    print(f"Merged safe_issuer: {df['safe_issuer'].sum():,} safe country-years "
          f"({df[df['safe_issuer']==1]['iso3'].nunique()} countries)")
    return df


def _fit_spec(est, dep_var, regressors, label, all_results):
    """Fit PanelGLS of ``est[dep_var]`` on ``est[regressors]`` and record it.

    Entities are ``iso3`` and periods are ``year``.  The per-model result
    table from ``fit_and_report`` is appended to *all_results*.

    Returns:
        The fitted model.
    """
    model, result_df = fit_and_report(
        est[dep_var].values, est[regressors].values,
        est['iso3'].values, est['year'].values,
        regressors, label
    )
    all_results.append(result_df)
    return model


def _coef_cell(c, p):
    """Format a coefficient table cell as ``+c.c (p.pppp)stars``."""
    return f"{fmt(c, 1)} ({p:.4f}){stars(p)}"


def _fit_stats(model):
    """Return the table cells ``[N, countries, R²]`` for a fitted model."""
    return [f"{model.n_obs:,}", f"{model.n_countries}", f"{model.r_squared:.3f}"]


def _spline_crossover(c_base, c_spline, knot):
    """Return the OADR level at which the fitted spline contribution crosses zero.

    The OADR contribution is f(o) = c_base * o + c_spline * max(0, o - knot).
    Below the knot it is c_base * o, which is zero only at o = 0.  Above the
    knot it is (c_base + c_spline) * o - c_spline * knot, which is zero at
    o* = c_spline * knot / (c_base + c_spline).  o* is a valid crossover only
    when it lies above the knot, i.e. when the slopes below and above the knot
    have opposite signs.  *knot* and the result share the units of ``old_dep``.

    Returns:
        o*, or NaN when there is no crossover.
    """
    slope_above = c_base + c_spline
    if slope_above == 0:
        return float('nan')
    crossover = c_spline * knot / slope_above
    return crossover if crossover > knot else float('nan')


def main():
    """Run every Phase 3b heterogeneity regression and write its tables.

    Sections:
      1. OADR spline on nominal r-g (full sample).
      2. Safe vs non-safe issuer split.
      3. Safe-issuer real r-g spline (console and CSV only).
      4. Pre/post GFC split.
      5. Component decomposition on the bond-yield sample.

    Markdown tables 5b-5e and the combined results CSV are written to TABLE_DIR.

    Returns:
        DataFrame concatenating every estimated model's results, or an empty
        DataFrame when no specification had enough observations.
    """
    print("=" * 70)
    print("PHASE 3b: r-g Heterogeneity — Spline, Safe Split, GFC Break")
    print("=" * 70)

    # The output directory may not exist yet on a fresh checkout.
    TABLE_DIR.mkdir(parents=True, exist_ok=True)

    df = _load_panel()

    all_results = []
    demo_vars = ['Z_1', 'Z_2', 'Z_3']
    controls = ['fiscal_bal_gdp', 'kaopen', 'log_rel_opw', 'nfa_gdp_lag']
    # Controls absent from the panel are dropped without warning.
    controls = [c for c in controls if c in df.columns]
    vars_oadr = ['old_dep'] + controls
    vars_z = demo_vars + controls

    # ==================================================================
    # SECTION 1: OADR SPLINE ON NOMINAL r-g (full sample)
    # ==================================================================
    print("\n\n" + "=" * 70)
    print("SECTION 1: OADR SPLINE REGRESSIONS ON NOMINAL r-g")
    print("=" * 70)

    spline_table_rows = []
    knots = [15, 20, 25, 30]

    # Continuous piecewise-linear spline terms max(0, OADR - knot), not dummies.
    # old_dep is treated as a fraction, so the percent knot is divided by 100.
    # They are built on the full panel, so Section 3's subsets already carry them.
    for knot in knots:
        df[f'oadr_spline_{knot}'] = np.maximum(0, df['old_dep'] - knot / 100.0)

    for knot in knots:
        knot_var = f'oadr_spline_{knot}'
        vars_s = ['old_dep', knot_var] + controls
        est = df.dropna(subset=['r_minus_g'] + vars_s)
        if len(est) < 200:
            continue

        m = _fit_spec(est, 'r_minus_g', vars_s,
                      f"1.{knot}: OADR spline (knot={knot}%) -> nominal r-g",
                      all_results)
        c_base, _, p_base = extract_coef(m, vars_s, 'old_dep')
        c_spline, _, p_spline = extract_coef(m, vars_s, knot_var)

        print(f"\n  >>> Knot={knot}%: base slope = {c_base:.1f} (p={p_base:.4f}), "
              f"spline = {c_spline:.1f} (p={p_spline:.4f})")
        if p_spline < 0.1:
            crossover = _spline_crossover(c_base, c_spline, knot / 100.0)
            if 0 < crossover < 1:  # NaN (no crossover) fails this test
                print(f"      Crossover OADR: {crossover*100:.1f}%")

        spline_table_rows.append(
            [f"{knot}%", _coef_cell(c_base, p_base), _coef_cell(c_spline, p_spline)]
            + _fit_stats(m)
        )

    write_markdown_table(
        TABLE_DIR / "table5b_rg_spline.md",
        "Table 5b: OADR Spline on Nominal r-g (Full Sample)",
        ["Knot", "OADR (base slope)", "Spline (above knot)", "N", "Countries", "R²"],
        spline_table_rows,
        notes="PanelGLS with AR(1). Controls: fiscal balance, KAOPEN, relative productivity, NFA."
    )

    # ==================================================================
    # SECTION 2: SAFE vs NON-SAFE ISSUER SPLIT
    # ==================================================================
    print("\n\n" + "=" * 70)
    print("SECTION 2: SAFE vs NON-SAFE ISSUER SPLIT")
    print("=" * 70)

    safe_table_rows = []

    for safe_label, safe_val in [("Safe issuers (AA- or above)", 1),
                                 ("Non-safe issuers", 0)]:
        sub = df[df['safe_issuer'] == safe_val]
        # (dependent var, regressors, reported coefficient, model label, table label)
        specs = [
            ('r_minus_g', vars_oadr, 'old_dep',
             f"2a. OADR -> nominal r-g ({safe_label})", "OADR → nominal r-g"),
            ('r_minus_g', vars_z, 'Z_1',
             f"2b. Z -> nominal r-g ({safe_label})", "Z₁ → nominal r-g"),
            ('r_minus_g_real', vars_oadr, 'old_dep',
             f"2c. OADR -> real r-g ({safe_label})", "OADR → real r-g"),
        ]
        for dep_var, regressors, key_var, label, spec_label in specs:
            if dep_var not in sub.columns:
                continue
            est = sub.dropna(subset=[dep_var] + regressors)
            if len(est) < 50:
                continue
            m = _fit_spec(est, dep_var, regressors, label, all_results)
            c, _, p = extract_coef(m, regressors, key_var)
            safe_table_rows.append([safe_label, spec_label, _coef_cell(c, p)] + _fit_stats(m))

    write_markdown_table(
        TABLE_DIR / "table5c_safe_split.md",
        "Table 5c: Safe vs Non-Safe Issuer Split on r-g",
        ["Sample", "Specification", "Coefficient (p-value)", "N", "Countries", "R²"],
        safe_table_rows,
        notes="Safe issuer = S&P rating AA- or above. PanelGLS with AR(1)."
    )

    # ==================================================================
    # SECTION 3: SAFE ISSUER REAL r-g SPLINE
    # ==================================================================
    print("\n\n" + "=" * 70)
    print("SECTION 3: SAFE ISSUER REAL r-g SPLINE")
    print("=" * 70)

    safe_sub = df[df['safe_issuer'] == 1]

    if 'r_minus_g_real' in safe_sub.columns:
        for knot in [20, 25, 30]:
            vars_s = ['old_dep', f'oadr_spline_{knot}'] + controls
            est = safe_sub.dropna(subset=['r_minus_g_real'] + vars_s)
            if len(est) >= 50:
                _fit_spec(est, 'r_minus_g_real', vars_s,
                          f"3.{knot}: Safe issuer OADR spline (knot={knot}%) -> real r-g",
                          all_results)

    # ==================================================================
    # SECTION 4: PRE/POST GFC SPLIT
    # ==================================================================
    print("\n\n" + "=" * 70)
    print("SECTION 4: PRE/POST GFC SPLIT")
    print("=" * 70)

    gfc_table_rows = []

    for period_label, (start, end) in [('Pre-GFC (1990-2007)', (1990, 2007)),
                                       ('Post-GFC (2010-2024)', (2010, 2024))]:
        period_df = df[df['year'].between(start, end)]  # inclusive on both ends

        for dep_var, dep_label in [('r_minus_g_real', 'Real r-g'),
                                   ('r_minus_g', 'Nominal r-g')]:
            if dep_var not in period_df.columns:
                continue
            est = period_df.dropna(subset=[dep_var] + vars_z)
            if len(est) < 100:
                continue
            m = _fit_spec(est, dep_var, vars_z,
                          f"4. Z -> {dep_label} ({period_label})", all_results)
            c, _, p = extract_coef(m, vars_z, 'Z_1')
            gfc_table_rows.append([period_label, dep_label, _coef_cell(c, p)] + _fit_stats(m))

    write_markdown_table(
        TABLE_DIR / "table5d_gfc_split.md",
        "Table 5d: Pre/Post GFC Split on r-g",
        ["Period", "Dependent Var", "Z₁ Coefficient (p-value)", "N", "Countries", "R²"],
        gfc_table_rows,
        notes="Pre-GFC: 1990–2007. Post-GFC: 2010–2024. PanelGLS with AR(1)."
    )

    # ==================================================================
    # SECTION 5: COMPONENT DECOMPOSITION (23-country bond yield sample)
    # ==================================================================
    print("\n\n" + "=" * 70)
    print("SECTION 5: COMPONENT DECOMPOSITION (bond yield sample)")
    print("=" * 70)

    decomp_table_rows = []

    # Bond yield sample: country-years with govt_bond_10y data.
    if 'govt_bond_10y' in df.columns:
        bond_sub = df.dropna(subset=['govt_bond_10y'])
    else:
        print("  WARNING: govt_bond_10y not in panel — decomposition skipped")
        bond_sub = df.iloc[0:0]
    print(f"  Bond yield sample: {len(bond_sub):,} obs, "
          f"{bond_sub['iso3'].nunique()} countries")

    # Components first, then r-g itself, all on the same bond-yield sample.
    decomp_specs = [('govt_bond_10y', 'Nominal bond yield'),
                    ('real_bond_10y', 'Real bond yield'),
                    ('rgdp_growth', 'Real GDP growth'),
                    ('inflation', 'Inflation'),
                    ('r_minus_g', 'r-g (nominal)'),
                    ('r_minus_g_real', 'r-g (real)')]
    for dep_var, dep_label in decomp_specs:
        if dep_var not in bond_sub.columns:
            continue
        est = bond_sub.dropna(subset=[dep_var] + vars_z)
        if len(est) < 50:
            continue
        m = _fit_spec(est, dep_var, vars_z,
                      f"5. Z -> {dep_label} (bond yield sample)", all_results)
        c, _, p = extract_coef(m, vars_z, 'Z_1')
        decomp_table_rows.append([dep_label, _coef_cell(c, p)] + _fit_stats(m))

    write_markdown_table(
        TABLE_DIR / "table5e_decomposition.md",
        "Table 5e: Component Decomposition (Bond Yield Sample)",
        ["Dependent Variable", "Z₁ Coefficient (p-value)", "N", "Countries", "R²"],
        decomp_table_rows,
        notes="Homogeneous sample of countries with 10-year government bond yields. PanelGLS with AR(1)."
    )

    # ==================================================================
    # Save all results
    # ==================================================================
    if not all_results:
        return pd.DataFrame()

    results_df = pd.concat(all_results, ignore_index=True)
    out_csv = TABLE_DIR / "phase3b_rg_heterogeneity.csv"
    results_df.to_csv(out_csv, index=False)
    print(f"\n{'=' * 70}")
    print(f"Saved: {out_csv}")
    print(f"  {len(results_df)} rows across {results_df['model'].nunique()} models")

    # Summary of the headline OADR / Z_1 coefficients across all models.
    print(f"\n{'=' * 70}")
    print("KEY RESULTS SUMMARY")
    print("=" * 70)
    z1_rows = results_df[results_df['variable'].isin(['Z_1', 'old_dep'])][
        ['model', 'variable', 'coefficient', 'std_error', 'p_value',
         'r_squared', 'n_obs', 'n_countries']
    ]
    print(z1_rows.to_string(index=False, float_format='%.4f'))
    return results_df


if __name__ == "__main__":
    # Bound at module level so the returned results DataFrame stays available,
    # e.g. when the script is run with `python -i`.
    results = main()
